{
  "nbformat": 4,
  "nbformat_minor": 5,
  "metadata": {
    "colab": {
      "name": "examples_te.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.10"
    },
    "toc": {
      "base_numbering": 1,
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nasty-intention"
      },
      "source": [
        "# Colab Demo"
      ],
      "id": "nasty-intention"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "previous-bacon"
      },
      "source": [
        "## Dependencies and Imports"
      ],
      "id": "previous-bacon"
    },
    {
      "cell_type": "code",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2021-06-18T06:34:38.757647Z",
          "start_time": "2021-06-18T06:34:37.545968Z"
        },
        "id": "complicated-receiver",
        "cellView": "form"
      },
      "source": [
        "#@title Install dependencies\n",
        "\n",
        "import os\n",
        "import yaml\n",
        "import torch\n",
        "from torch import package\n",
        "\n",
        "torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',\n",
        "                               'latest_silero_models.yml',\n",
        "                               progress=False)\n",
        "\n",
        "with open('latest_silero_models.yml', 'r') as yaml_file:\n",
        "    models = yaml.load(yaml_file, Loader=yaml.SafeLoader)\n",
        "model_conf = models.get('te_models').get('latest')"
      ],
      "id": "complicated-receiver",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2021-06-18T06:34:47.677954Z",
          "start_time": "2021-06-18T06:34:47.665753Z"
        },
        "id": "pacific-injury"
      },
      "source": [
        "# see avaiable languages\n",
        "available_languages = list(model_conf.get('languages'))\n",
        "print(f'Available languages {available_languages}')\n",
        "\n",
        "# and punctuation marks\n",
        "available_punct = list(model_conf.get('punct'))\n",
        "print(f'Available punctuation marks {available_punct}')"
      ],
      "id": "pacific-injury",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3987bc23"
      },
      "source": [
        "## With language label"
      ],
      "id": "3987bc23"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c8f886e8"
      },
      "source": [
        "model_url = model_conf.get('package')\n",
        "\n",
        "model_dir = \"downloaded_model\"\n",
        "os.makedirs(model_dir, exist_ok=True)\n",
        "model_path = os.path.join(model_dir, os.path.basename(model_url))\n",
        "\n",
        "if not os.path.isfile(model_path):\n",
        "    torch.hub.download_url_to_file(model_url,\n",
        "                                   model_path,\n",
        "                                   progress=True)\n",
        "\n",
        "imp = package.PackageImporter(model_path)\n",
        "model = imp.load_pickle(\"te_model\", \"model\")\n",
        "example_texts = model.examples\n",
        "\n",
        "def apply_te(text, lan='en'):\n",
        "    return model.enhance_text(text, lan)"
      ],
      "id": "c8f886e8",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BzyUrUR0qqri"
      },
      "source": [
        "#### Example input"
      ],
      "id": "BzyUrUR0qqri"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ebe40114"
      },
      "source": [
        "input_text = model.examples[0]\n",
        "output_text = apply_te(input_text, lan='en')\n",
        "print(f\"Input: \\n{input_text}\\nOutput:\\n{output_text}\") "
      ],
      "id": "ebe40114",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kwCu8rpqq9E2"
      },
      "source": [
        "#### Input from the keyboard"
      ],
      "id": "kwCu8rpqq9E2"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MkrEWElqqiOz"
      },
      "source": [
        "input_text = input('Enter input text\\n')\n",
        "print(apply_te(input_text, lan='en'))"
      ],
      "id": "MkrEWElqqiOz",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fae7bbef"
      },
      "source": [
        "### With fasttext for language detection"
      ],
      "id": "fae7bbef"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FP4Gny82uy2C"
      },
      "source": [
        "! pip install fasttext-langdetect\n",
        "! pip install wget"
      ],
      "id": "FP4Gny82uy2C",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7bdb5e65"
      },
      "source": [
        "from ftlangdetect import detect\n",
        "\n",
        "input_text = input('Enter input text\\n')\n",
        "lan = detect(text=input_text, low_memory=False)['lang']\n",
        "print(f\"Detected language: {lan}\")\n",
        "print(apply_te(input_text, lan=lan))"
      ],
      "id": "7bdb5e65",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2021-03-29T06:46:39.993648Z",
          "start_time": "2021-03-29T06:46:39.052349Z"
        },
        "id": "stupid-naples"
      },
      "source": [
        "import torch\n",
        "\n",
        "model, example_texts, languages, punct, apply_te = torch.hub.load(repo_or_dir='snakers4/silero-models',\n",
        "                                                                  model='silero_te')"
      ],
      "id": "stupid-naples",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "indirect-berry"
      },
      "source": [
        "input_text = input('Enter input text\\n')\n",
        "apply_te(input_text, lan='en')"
      ],
      "id": "indirect-berry",
      "execution_count": null,
      "outputs": []
    }
  ]
}